Fitting B time series (animation)¶

This example demonstrates how to fit an SECS system to generic time-varying magnetic field (B) observations, use the fit to make predictions on a separate grid, and compare the predictions with the observations.

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np

from pysecs import SECS


# Reference radius of the Earth (6371e3; metres, assuming SI units)
R_earth = 6371e3

# SECS grid: 30 x 30 nodes spanning (-20, 20) in latitude and longitude,
# all placed on a single shell 110000 above R_earth
lat, lon, r = np.meshgrid(
    np.linspace(-20, 20, 30),
    np.linspace(-20, 20, 30),
    R_earth + 110000,
    indexing='ij',
)
# Flatten to one row per node with columns (lat, lon, r);
# with 'ij' indexing the longitude varies fastest along the rows
secs_lat_lon_r = np.column_stack((lat.ravel(), lon.ravel(), r.ravel()))

# Set up the SECS class, handing it every (lat, lon, r) row of the
# SECS grid as a system location.
# NOTE(review): the `sec_df_loc` keyword presumably refers to the
# divergence-free system locations -- confirm against the pysecs docs.
secs = SECS(sec_df_loc=secs_lat_lon_r)

# Observation sites: an 11 x 11 grid spanning (-10, 10) in latitude
# and longitude, placed at radius R_earth
lat, lon, r = np.meshgrid(
    np.linspace(-10, 10, 11),
    np.linspace(-10, 10, 11),
    R_earth,
    indexing='ij',
)
# One row per site with columns (lat, lon, r)
obs_lat_lon_r = np.column_stack((lat.ravel(), lon.ravel(), r.ravel()))
nobs = obs_lat_lon_r.shape[0]
# 2D (lat, lon) slices of the grid, kept for drawing the color meshes
obs_lat = lat[:, :, 0]
obs_lon = lon[:, :, 0]

# Synthetic magnetic field as a function of time: the first two
# components trace a circle of amplitude 5 over one full period and
# the third grows linearly (np.linspace defaults to 50 samples)
ts = np.linspace(0, 2*np.pi)
bx = 5*np.cos(ts)
by = 5*np.sin(ts)
bz = ts
# Stack the components into an (ntimes, 3) array
B_obs = np.column_stack((bx, by, bz))
ntimes = B_obs.shape[0]

# Give every observatory the same time series: (ntimes, nobs, 3)
B_obs = np.repeat(B_obs[:, None, :], nobs, axis=1)
# Make it more interesting by also modulating in space: component 0 is
# scaled by a sine of the site latitude, component 1 by a sine of the
# site longitude (both given in degrees)
B_obs[..., 0] *= 2*np.sin(np.deg2rad(obs_lat_lon_r[:, 0]))
B_obs[..., 1] *= 2*np.sin(np.deg2rad(obs_lat_lon_r[:, 1]))

# Unit variance for every observation ...
B_var = np.ones_like(B_obs)
# ... except the Z component, which is ignored via an infinite variance
B_var[..., 2] = np.inf
# The variance can also be modified as a function of time to see how
# that changes the fits, e.g.
# B_var[:, 0, 1] = 1 + ts

# Fit the SECS system to the data; this needs the observation
# locations and the observed field (the variances are passed as well)
secs.fit(obs_loc=obs_lat_lon_r, obs_B=B_obs, obs_var=B_var)

# Prediction points: an 11 x 11 grid at radius R_earth that reaches a
# little beyond the observation points, spanning (-11, 11)
lat, lon, r = np.meshgrid(
    np.linspace(-11, 11, 11),
    np.linspace(-11, 11, 11),
    R_earth,
    indexing='ij',
)
# One row per prediction point with columns (lat, lon, r)
pred_lat_lon_r = np.column_stack((lat.ravel(), lon.ravel(), r.ravel()))

# Evaluate the fitted system at the prediction points
B_pred = secs.predict(pred_lat_lon_r)

# Now set up the plots: observations on the top row, predictions on
# the bottom row, with one column per plotted component
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(
    nrows=2, ncols=2, sharex=True, sharey=True)
vmin, vmax = -5, 5
cmap = plt.get_cmap('RdBu_r')
# Time index drawn in the initial frame
t = 10
qscale = 1

# Drawing options shared by all four panels
mesh_style = dict(vmin=vmin, vmax=vmax, cmap=cmap, shading='auto')
arrow_style = dict(angles='xy', scale_units='xy', scale=qscale)

# Top-left panel: observed component 0 as colors, with arrows on top.
# Arrows sit at (lon, lat) and use component 1 along the x (longitude)
# axis and component 0 along the y (latitude) axis.
mesh1 = ax1.pcolormesh(obs_lon, obs_lat,
                       B_obs[t, :, 0].reshape(obs_lon.shape), **mesh_style)
q1 = ax1.quiver(obs_lat_lon_r[:, 1], obs_lat_lon_r[:, 0],
                B_obs[t, :, 1], B_obs[t, :, 0], **arrow_style)
ax1.set_ylabel("Observations")
ax1.set_title("B$_{X}$")

# Top-right panel: observed component 1 as colors, same arrows
mesh2 = ax2.pcolormesh(obs_lon, obs_lat,
                       B_obs[t, :, 1].reshape(obs_lon.shape), **mesh_style)
q2 = ax2.quiver(obs_lat_lon_r[:, 1], obs_lat_lon_r[:, 0],
                B_obs[t, :, 1], B_obs[t, :, 0], **arrow_style)
ax2.set_title("B$_{Y}$")

# Drop the singleton radius axis of the prediction grid so that
# lat/lon are 2D arrays usable as mesh coordinates
lat = lat[:, :, 0]
lon = lon[:, :, 0]

# Bottom-left panel: predicted component 0 as colors, with arrows
mesh3 = ax3.pcolormesh(lon, lat,
                       B_pred[t, :, 0].reshape(lon.shape), **mesh_style)
q3 = ax3.quiver(pred_lat_lon_r[:, 1], pred_lat_lon_r[:, 0],
                B_pred[t, :, 1], B_pred[t, :, 0], **arrow_style)
ax3.set_ylabel("Predictions")
ax3.set_title("B$_{X}$")

# Bottom-right panel: predicted component 1 as colors, same arrows
mesh4 = ax4.pcolormesh(lon, lat,
                       B_pred[t, :, 1].reshape(lon.shape), **mesh_style)
q4 = ax4.quiver(pred_lat_lon_r[:, 1], pred_lat_lon_r[:, 0],
                B_pred[t, :, 1], B_pred[t, :, 0], **arrow_style)
ax4.set_title("B$_{Y}$")


def update_axes(t):
    """Redraw all four panels for the time index ``t``.

    Pushes the observed and predicted field values at step ``t`` into
    the existing mesh and quiver artists in place; nothing is returned.
    """
    obs_c0, obs_c1 = B_obs[t, :, 0], B_obs[t, :, 1]
    pred_c0, pred_c1 = B_pred[t, :, 0], B_pred[t, :, 1]

    # Update the mesh colors, reshaping each flat series back onto the
    # 2D grid the mesh was created with
    mesh_updates = (
        (mesh1, obs_c0, obs_lon.shape),
        (mesh2, obs_c1, obs_lon.shape),
        (mesh3, pred_c0, lon.shape),
        (mesh4, pred_c1, lon.shape),
    )
    for mesh, values, grid_shape in mesh_updates:
        mesh.set_array(values.reshape(grid_shape))

    # Update the quiver arrows: component 1 along x, component 0 along y
    for arrows in (q1, q2):
        arrows.set_UVC(obs_c1, obs_c0)
    for arrows in (q3, q4):
        arrows.set_UVC(pred_c1, pred_c0)


# Animate the figure by calling update_axes once per time step, with
# an interval of 50 between frames; the animation object is bound to
# `ani` so that it stays referenced while the figure is shown
ani = FuncAnimation(
    fig=fig,
    func=update_axes,
    frames=range(ntimes),
    interval=50,
)

plt.show()

Total running time of the script: ( 0 minutes 19.069 seconds)

Gallery generated by Sphinx-Gallery